#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Sulfolane diffusion–clustering model: reference implementation (Supplementary Code)

This script provides two complementary implementations used to generate figures:

1) Nonlinear PDE solver (finite-difference, conservative flux form) for spatial profiles
   -> Figures 2, S1, S2 style outputs (profiles c1(x,t))

2) Semi-analytical time-response solver using Duhamel's principle (numerical convolution)
   -> Figure 3 style outputs (time traces c1(t) at fixed depths)

Notes
-----
- This code is written as a clear, reproducible reference implementation (not optimized).
- All parameters are collected near the top of the file.
- The PDE is advanced with explicit time stepping and an adaptive timestep (CFL-type) based on
  the maximum effective diffusivity in the domain.
- Clustering is treated via an auxiliary conserved variable S = u + Lambda u^n (here n=2),
  which is inverted analytically for u each step (quadratic formula for n=2).

Dependencies
------------
numpy, matplotlib, pillow (PIL)

Usage
-----
python sulfolane_model_reference.py
"""

import io
import math

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


# -------------------------
# Global / default parameters
# -------------------------
# Module-level constants read by the helper, solver and plotting functions
# below. Units are given in square brackets.

# Geometry / domain
XMAX_CM = 2.0            # domain depth into porous matrix [cm]
XMAX = XMAX_CM / 100.0   # same depth in [m]; right edge of the PDE grid

# Porous medium
phi = 0.10               # porosity [-]
tau = 3.0                # tortuosity [-]; De = (phi / tau) * D1
rho_b = 2000.0           # bulk density [kg/m^3]; enters Rf = 1 + rho_b Kd / phi

# Clustering model (dimerization)
# NOTE: the PDE solver inverts S = u + Lambda u^n with the quadratic formula
# and uses the flux factor A(u) = 1 + Gamma u; both forms assume n == 2.
n = 2                    # cluster order (n=2 -> dimers)
Dc_ratio = 0.05          # Dc/D1 [-] (clusters diffuse more slowly)

# Boundary dilution/leaching timescale: the fracture-wall concentration
# decays as exp(-t / t_d).
t_d_days = 30.0          # [days]
t_d = t_d_days * 24 * 3600.0   # [s]

# Temperature dependence of molecular diffusion D1(T) (Arrhenius form)
T_ref = 303.0            # reference temperature [K]
D1_ref = 1.0e-9          # [m^2/s] (dilute aqueous estimate at T_ref)
E_D = 18e3               # activation energy [J/mol]
R_gas = 8.314            # gas constant [J/mol/K]

# Plot/export
DPI_PNG = 300            # resolution passed to savefig for PNG output
DPI_TIFF = 600           # dpi used for TIFF export


# -------------------------
# Helper functions
# -------------------------

def D1_of_T(T):
    """Return the free-solution monomer diffusivity D1 [m^2/s] at temperature T [K].

    Arrhenius scaling anchored at (T_ref, D1_ref):

        D1(T) = D1_ref * exp[-(E_D / R_gas) * (1/T - 1/T_ref)]

    Works elementwise when T is a NumPy array.
    """
    inverse_temperature_shift = 1.0 / T - 1.0 / T_ref
    activation = E_D / R_gas
    return D1_ref * np.exp(-activation * inverse_temperature_shift)


def De_from_D1(D1):
    """Convert a free-solution diffusivity into a porous-medium effective one.

    De = (phi / tau) * D1, using the module-level porosity and tortuosity.
    """
    porosity_over_tortuosity = phi / tau
    return porosity_over_tortuosity * D1


def retardation_factor(Kd):
    """Return the linear-equilibrium sorption retardation factor.

    Rf = 1 + rho_b * Kd / phi, with rho_b the bulk density and phi the porosity.
    """
    sorbed_term = rho_b * Kd / phi
    return 1.0 + sorbed_term


# =============================================================================
# 1) Nonlinear PDE solver for spatial profiles (finite difference, conservative flux)
# =============================================================================

def solve_profiles_nonlinear(
    T=303.0,
    Kd=1e-5,
    Ka=0.05,
    L_ref=0.01,
    Nx=201,
    safety=0.45,
    t_days_list=(1, 7, 30, 90, 180),
    return_dimensionless=True,
):
    """
    Solve the nonlinear PDE for u(x, t) = c1/c_ref over x in [0, XMAX].

    PDE (dimensionless form used internally)
    ----------------------------------------
    Let u = c1/c_ref.
    Define effective monomer and cluster diffusivities in the porous medium:
        De1 = (phi/tau) D1(T)
        Dec = (phi/tau) Dc, with Dc = Dc_ratio * D1(T)

    Retardation (equilibrium sorption):
        Rf = 1 + rho_b Kd / phi

    Dimensionless coordinate xi = x / L_ref and dimensionless time theta:
        theta = (De1 / (Rf L_ref^2)) t

    Boundary condition (decaying source at fracture wall):
        u(0, theta) = exp(-Da_d * theta),  where Da_d = Rf L_ref^2 / (De1 t_d)
        u(xi_max, theta) = 0
        u(xi, 0) = 0

    Clustering is represented via an auxiliary "conserved" variable:
        S(u) = u + Lambda u^n  (here n=2)
      with Lambda = n Ka / Rf  (for c_ref scaling = 1).

    Flux:
        F = A(u) du/dxi
      with A(u) = 1 + Gamma u, where Gamma = n^2 Ka (Dec/De1) = n^2 Ka Dc_ratio.

    Conservative form (dimensionless):
        dS/dtheta = d/dxi ( A(u) du/dxi )

    Discretization
    --------------
    Uniform grid in xi. Central differences for gradients; explicit Euler in theta.
    Adaptive timestep dt <= safety * dxi^2 / max(A(u)). The step is additionally
    shortened so that every requested snapshot time is hit exactly rather than
    overshot by up to one step.

    Parameters
    ----------
    T : float
        Temperature [K] used for D1(T).
    Kd : float
        Sorption distribution coefficient [m^3/kg] in Rf = 1 + rho_b Kd / phi.
    Ka : float
        Clustering constant entering Lambda and Gamma.
    L_ref : float
        Length scale [m] used to non-dimensionalise x.
    Nx : int
        Number of grid nodes (at least 2).
    safety : float
        Fraction of dxi^2 / max(A) used as the step. Explicit Euler for this
        diffusion problem needs safety <= 0.5 for stability.
    t_days_list : iterable of float
        Snapshot times [days]. Any order is accepted; snapshots are computed in
        ascending order. Times <= 0 return the initial state (zero interior,
        wall value imposed at xi = 0).
    return_dimensionless : bool
        If True return the xi grid, otherwise x = xi * L_ref [m].

    Returns
    -------
    xi (or x), profiles_dict, info_dict
        profiles_dict maps float(t_day) -> copy of u on the grid, inserted in
        ascending time order. info_dict holds D1, De1, Dec, Rf, Lambda, Gamma, Da_d.

    Raises
    ------
    NotImplementedError
        If the module-level cluster order n is not 2. The S -> u inversion and
        A(u) below are written for dimers only.
    ValueError
        If Nx < 2 or a snapshot time is not finite.
    """
    if n != 2:
        raise NotImplementedError(
            "solve_profiles_nonlinear assumes dimers (n = 2): the S -> u inversion "
            "is the quadratic formula and A(u) = 1 + Gamma u is the n = 2 flux factor."
        )
    if Nx < 2:
        raise ValueError("Nx must be at least 2")

    D1 = float(D1_of_T(T))
    De1 = float(De_from_D1(D1))
    Dec = float(De_from_D1(Dc_ratio * D1))
    Rf = float(retardation_factor(Kd))

    Lambda = (n * Ka) / Rf
    Gamma = (n * n * Ka) * (Dec / De1)
    Da_d = (Rf * (L_ref ** 2)) / (De1 * t_d)

    xi_max = XMAX / L_ref
    xi = np.linspace(0.0, xi_max, Nx)
    dxi = xi[1] - xi[0]

    def S_of_u(u):
        """Conserved storage variable S = u + Lambda u^n."""
        return u + Lambda * (u ** n)

    def u_of_S(S):
        """Invert S = u + Lambda u^2 for the non-negative root (n = 2)."""
        if Lambda == 0.0:
            return S
        # Clamp guards sqrt against round-off pushing the discriminant below 0.
        disc = np.maximum(1.0 + 4.0 * Lambda * S, 0.0)
        return (-1.0 + np.sqrt(disc)) / (2.0 * Lambda)

    def A_of_u(u):
        """Dimensionless flux coefficient A(u) = 1 + Gamma u."""
        return 1.0 + Gamma * u

    def u_wall(theta_now):
        """Decaying source at the fracture wall, u(0, theta) = exp(-Da_d theta)."""
        return math.exp(-Da_d * theta_now)

    # Sort snapshot times so they are reached in order regardless of input order.
    t_days = np.sort(np.array(list(t_days_list), dtype=float))
    t_seconds = t_days * 24 * 3600.0
    theta_targets = (De1 * t_seconds) / (Rf * (L_ref ** 2))
    if not np.all(np.isfinite(theta_targets)):
        raise ValueError("t_days_list must contain only finite times")
    n_targets = len(theta_targets)

    u = np.zeros_like(xi)
    S = S_of_u(u)
    theta = 0.0
    profiles = {}
    k = 0  # index of the next snapshot still to be stored

    while True:
        # Impose the Dirichlet conditions for the current time level.
        u0 = u_wall(theta)
        u[0] = u0
        u[-1] = 0.0
        S[0] = S_of_u(u0)
        S[-1] = 0.0

        # Store every snapshot whose time has been reached (t <= 0 -> initial state).
        while k < n_targets and theta >= theta_targets[k] - 1e-12:
            profiles[float(t_days[k])] = u.copy()
            k += 1
        if k == n_targets:
            break

        maxA = float(np.max(A_of_u(u)))
        dt = safety * (dxi ** 2) / max(maxA, 1e-12)
        # Land exactly on the next snapshot time instead of stepping past it.
        # gap > 1e-12 here because the snapshot loop above did not consume it.
        gap = float(theta_targets[k]) - theta
        if dt >= gap:
            dt = gap
            theta_new = float(theta_targets[k])
        else:
            theta_new = theta + dt

        # Conservative face fluxes F_{i+1/2} = A_{i+1/2} (u_{i+1} - u_i) / dxi.
        uL = u[:-1]
        uR = u[1:]
        A_half = 0.5 * (A_of_u(uL) + A_of_u(uR))
        F_half = A_half * (uR - uL) / dxi

        divF = np.zeros_like(u)
        divF[1:-1] = (F_half[1:] - F_half[:-1]) / dxi

        # Explicit Euler on the conserved variable, then recover u from S.
        S = S + dt * divF
        S[0] = S_of_u(u_wall(theta_new))
        S[-1] = 0.0

        u = u_of_S(S)
        np.clip(u, 0.0, 1.0, out=u)
        S = S_of_u(u)

        theta = theta_new

    info = dict(D1=D1, De1=De1, Dec=Dec, Rf=Rf, Lambda=Lambda, Gamma=Gamma, Da_d=Da_d)

    if return_dimensionless:
        return xi, profiles, info
    return xi * L_ref, profiles, info


def plot_profiles(xgrid, profiles, title, outbase, use_dimensionless=True):
    """Plot stored profiles u(x) at several times (lines only) and export them.

    Parameters
    ----------
    xgrid : array_like
        Spatial grid (xi or x) matching each stored profile.
    profiles : dict
        Maps a time label in days -> profile array; plotted in iteration order.
    title : str
        Figure title.
    outbase : str
        Output path without extension; ".png" and ".tiff" are appended.
    use_dimensionless : bool
        Selects the x-axis label (xi vs. x in metres).

    Returns
    -------
    (png_path, tiff_path)

    The PNG is rendered at DPI_PNG. The TIFF is rendered separately at DPI_TIFF,
    then flattened to RGB and LZW-compressed, so its pixel count matches the dpi
    written in its header. Re-tagging lower-resolution pixels with a higher dpi
    would make the figure print smaller than intended.
    """
    fig, ax = plt.subplots(figsize=(6.0, 5.0))
    try:
        for tday, u in profiles.items():
            ax.plot(xgrid, u, linewidth=2.0, label=f"{tday:g} d")
        ax.set_ylim(0, 1.05)
        ax.set_xlabel(r"$\xi=x/L_{\mathrm{ref}}$" if use_dimensionless else "x (m)")
        ax.set_ylabel(r"$c_1/c_{\mathrm{ref}}$")
        ax.set_title(title)
        if profiles:  # a legend with no labelled artists only produces a warning
            ax.legend(frameon=False)
        fig.tight_layout()

        png = f"{outbase}.png"
        tif = f"{outbase}.tiff"
        fig.savefig(png, dpi=DPI_PNG, bbox_inches="tight")
        tiff_buffer = io.BytesIO()
        fig.savefig(tiff_buffer, format="png", dpi=DPI_TIFF, bbox_inches="tight")
    finally:
        plt.close(fig)

    tiff_buffer.seek(0)
    with Image.open(tiff_buffer) as img:
        img.convert("RGB").save(tif, format="TIFF", compression="tiff_lzw", dpi=(DPI_TIFF, DPI_TIFF))
    return png, tif


# =============================================================================
# 2) Semi-analytical time traces (Duhamel convolution) for Figure 3
# =============================================================================

def c_boundary(t):
    """Return the normalised fracture-wall concentration exp(-t / t_d).

    t is in seconds (scalar or NumPy array); t_d is the module-level decay time.
    """
    decay = t / t_d
    return np.exp(-decay)


def duhamel_convolution(x, t, Dapp, ns=2000, boundary=None):
    """
    Duhamel convolution for diffusion in a semi-infinite domain with time-dependent boundary.

    Computes

        c(x, t) = int_0^t f(s) K(x, t - s) ds,
        K(x, u) = x / (2 sqrt(pi Dapp u^3)) * exp(-x^2 / (4 Dapp u)),

    i.e. the response at depth x of an initially empty semi-infinite medium
    whose surface concentration follows f(s).

    Numerical method
    ----------------
    K(x, u) is sharply peaked near u = x^2 / (6 Dapp). At shallow depths this
    can be far shorter than t, and a uniform grid in u then steps over the peak
    and misrepresents the kernel mass. Substituting w = x / (2 sqrt(Dapp u))
    gives the equivalent, smooth Gaussian-weighted form

        c(x, t) = (2 / sqrt(pi)) int_{w0}^{inf} f(t - x^2 / (4 Dapp w^2)) exp(-w^2) dw,
        w0 = x / (2 sqrt(Dapp t)).

    This form is integrated with the trapezoidal rule on a geometric grid from
    w0 to w0 + 6; exp(-w^2) < 3e-16 beyond that. For f = 1 it reduces exactly
    to erfc(w0).

    Parameters
    ----------
    x : float
        Depth [m], must be >= 0. x == 0 returns the boundary value f(t).
    t : float
        Time [s]. t <= 0 returns 0.0.
    Dapp : float
        Apparent diffusivity [m^2/s], > 0.
    ns : int
        Number of quadrature nodes.
    boundary : callable, optional
        Vectorised boundary history f(s) with s in seconds; defaults to c_boundary.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If x < 0.
    """
    if t <= 0:
        return 0.0
    if x < 0:
        raise ValueError("x must be non-negative (depth into the matrix)")
    f = c_boundary if boundary is None else boundary
    if x == 0:
        # Limit x -> 0+: the medium sees the boundary value directly.
        return float(f(t))

    w0 = x / (2.0 * math.sqrt(Dapp * t))
    # Geometric spacing clusters nodes near w0. That is where
    # f(t - x^2/(4 Dapp w^2)) varies fastest when w0 is small
    # (shallow depth / late time).
    w = w0 * np.geomspace(1.0, 1.0 + 6.0 / w0, ns)
    lag = (x * x) / (4.0 * Dapp * w * w)   # u = t - s; equals t at w = w0
    s = np.maximum(t - lag, 0.0)           # clip round-off just below s = 0
    integrand = f(s) * np.exp(-w * w)
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return float((2.0 / math.sqrt(math.pi)) * trapezoid(integrand, w))


def effective_apparent_diffusivity(T, Kd, Ka):
    """
    Apparent diffusivity for the linearized time-trace model:

        Dapp = De1 / (Rf * psi),  psi = 1 + Gamma,  Gamma = n^2 Ka * Dc_ratio.

    Returns
    -------
    (Dapp, details)
        details is a dict with keys D1, De1, Rf, Gamma, psi.

    NOTE(review): in solve_profiles_nonlinear the same Gamma enters the flux
    coefficient A(u) = 1 + Gamma u, where it enhances transport, and clustering
    also adds storage through Lambda. Here Gamma instead divides Dapp and Lambda
    does not appear. Confirm this is the linearization intended in the manuscript.
    """
    D1 = float(D1_of_T(T))
    De1 = float(De_from_D1(D1))
    Rf = float(retardation_factor(Kd))
    Gamma = n * n * Ka * Dc_ratio
    psi = 1.0 + Gamma
    Dapp = De1 / (Rf * psi)
    details = {"D1": D1, "De1": De1, "Rf": Rf, "Gamma": Gamma, "psi": psi}
    return Dapp, details


def time_traces(chem_case, depths_cm=(0.1, 0.5, 1.0), Temps=(293.0, 303.0, 313.0), tmax_days=180.0, Nt=400, ns_conv=1800):
    """Evaluate semi-analytical time traces c1/c_ref(t) at fixed depths and temperatures.

    Parameters
    ----------
    chem_case : mapping
        Must provide "Kd" and "Ka".
    depths_cm : iterable of float
        Depths [cm]; converted to metres before evaluation.
    Temps : iterable of float
        Temperatures [K].
    tmax_days, Nt : float, int
        The time axis is np.linspace(0.1, tmax_days, Nt) in days.
    ns_conv : int
        Quadrature size forwarded to duhamel_convolution as ``ns``.

    Returns
    -------
    t_days, traces, meta
        traces[(depth_cm, T)] is an array of Nt concentrations.
        meta[(depth_cm, T)] holds Dapp plus the diagnostics returned by
        effective_apparent_diffusivity. Keys are inserted depth-major.
    """
    t_days = np.linspace(0.1, tmax_days, Nt)
    t_seconds = t_days * 24 * 3600.0

    traces = {}
    meta = {}
    for depth_cm in depths_cm:
        x_m = depth_cm / 100.0
        for temp_K in Temps:
            key = (depth_cm, temp_K)
            Dapp, diag = effective_apparent_diffusivity(temp_K, chem_case["Kd"], chem_case["Ka"])
            traces[key] = np.fromiter(
                (duhamel_convolution(x_m, t_s, Dapp, ns=ns_conv) for t_s in t_seconds),
                dtype=float,
                count=t_seconds.size,
            )
            meta[key] = {"Dapp": Dapp, **diag}

    return t_days, traces, meta


def plot_time_traces(t_days, traces, title, outbase, depths_cm=(0.1, 0.5, 1.0), Temps=(293.0, 303.0, 313.0)):
    """Plot time traces as stacked panels (one per depth), lines only, and export them.

    Parameters
    ----------
    t_days : array_like
        Time axis in days.
    traces : dict
        Maps (depth_cm, T) -> concentration array aligned with t_days.
    title : str
        Title placed on the top panel.
    outbase : str
        Output path without extension; ".png" and ".tiff" are appended.
    depths_cm : sequence of float
        One panel per depth, top to bottom. The figure height scales with the
        panel count (6.5 in for three panels).
    Temps : iterable of float
        One line per temperature in every panel.

    Returns
    -------
    (png_path, tiff_path)

    Raises
    ------
    ValueError
        If depths_cm is empty.

    The legend goes on the middle panel. The PNG is rendered at DPI_PNG. The
    TIFF is rendered separately at DPI_TIFF, then flattened to RGB and
    LZW-compressed, so its pixel count matches the dpi written in its header.
    """
    n_panels = len(depths_cm)
    if n_panels == 0:
        raise ValueError("depths_cm must contain at least one depth")
    legend_panel = n_panels // 2

    fig, axes = plt.subplots(n_panels, 1, figsize=(5.5, 6.5 * n_panels / 3.0), sharex=True, squeeze=False)
    axes = axes[:, 0]
    try:
        for j, (ax, dcm) in enumerate(zip(axes, depths_cm)):
            for T in Temps:
                ax.plot(t_days, traces[(dcm, T)], linewidth=2.0, label=f"{T:.0f} K")
            ax.set_ylim(0, 1.05)
            ax.set_ylabel(f"x={dcm:g} cm\n$c_1/c_{{ref}}$")
            if j == 0:
                ax.set_title(title)
            if j == legend_panel:
                ax.legend(frameon=False, ncols=3, fontsize=9, loc="upper right")
        axes[-1].set_xlabel("Time (days)")
        fig.tight_layout()

        png = f"{outbase}.png"
        tif = f"{outbase}.tiff"
        fig.savefig(png, dpi=DPI_PNG, bbox_inches="tight")
        tiff_buffer = io.BytesIO()
        fig.savefig(tiff_buffer, format="png", dpi=DPI_TIFF, bbox_inches="tight")
    finally:
        plt.close(fig)

    tiff_buffer.seek(0)
    with Image.open(tiff_buffer) as img:
        img.convert("RGB").save(tif, format="TIFF", compression="tiff_lzw", dpi=(DPI_TIFF, DPI_TIFF))
    return png, tif


# =============================================================================
# Demonstration / reproducibility run
# =============================================================================

def main():
    """Regenerate the example figures and print the key derived parameters.

    Runs the nonlinear PDE profiles and the semi-analytical time traces for the
    low-salt and high-sulfate chemistry cases. It writes PNG + TIFF pairs into
    the current working directory and prints the low-salt diagnostics.
    """
    # Chemistry cases (as used in the manuscript)
    low_salt = {"Kd": 1e-5, "Ka": 0.05}
    high_sulf = {"Kd": 1e-4, "Ka": 0.5}
    snapshot_days = (1, 7, 30, 90, 180)

    # --- Nonlinear PDE spatial profiles (example) ---
    profile_runs = (
        (low_salt, "Profiles (low salt, 303 K)", "profiles_low_salt_303K"),
        (high_sulf, "Profiles (high sulfate, 303 K)", "profiles_high_sulf_303K"),
    )
    pde_infos = []
    for case, label, stem in profile_runs:
        grid, profs, info = solve_profiles_nonlinear(T=303.0, **case, Nx=201, t_days_list=snapshot_days)
        plot_profiles(grid, profs, title=label, outbase=stem, use_dimensionless=True)
        pde_infos.append(info)

    # --- Semi-analytical time traces (example) ---
    trace_runs = (
        (low_salt, "Figure 3a: Low salt (lines)", "Figure3a_lines_only_demo"),
        (high_sulf, "Figure 3b: High sulfate (lines)", "Figure3b_lines_only_demo"),
    )
    trace_metas = []
    for case, label, stem in trace_runs:
        days, traces, meta = time_traces(case, Nt=400, ns_conv=1800)
        plot_time_traces(days, traces, title=label, outbase=stem)
        trace_metas.append(meta)

    # Print key parameters for transparency
    print("=== Nonlinear PDE example (low salt, 303 K) ===")
    for key, value in pde_infos[0].items():
        print(f"{key:7s}: {value:.3e}")
    print("=== Semi-analytical example (low salt, depth=0.5 cm, 303 K) ===")
    sample = trace_metas[0][(0.5, 303.0)]
    for key in ("D1", "De1", "Rf", "Gamma", "psi", "Dapp"):
        print(f"{key:5s}: {sample[key]:.3e}")

    print("\nExample outputs written to the current working directory (PNG + TIFF).")


if __name__ == "__main__":
    main()
